Add Megatron-LM cross-entropy integration#1207
Conversation
Mecoli1219 left a comment:
Overall looks great! Excited to support Megatron with Liger. Left some comments to address.
```python
if tp_size > 1:
    raise RuntimeError(
        f"apply_liger_kernel_to_megatron currently requires tensor_model_parallel_size=1, "
        f"got {tp_size}. Vocab-parallel cross-entropy support is planned as follow-up work."
    )
```
This is a constraint that needs to be addressed in the future, given that TP is a common use case in Megatron, but it's a great start for supporting Megatron!
BTW, does this patching also not support other parallel strategies (Sequence Parallel, etc.)?
It feels a bit awkward to me to have the patching and function-wrapping logic on the Liger side. It is certainly a simpler way to use Liger's CE without touching the Megatron codebase. However, if supporting the Megatron framework is not on our roadmap and we are not going to add it to our test suite in the short term, it will be quite inconvenient to maintain this support in the long run. WDYT?
BTW, Megatron's SP requires TP > 1.
```python
global _ACTIVATION_LOGGED
if not _ACTIVATION_LOGGED:
```
```python
    return liger_fused_vocab_parallel_cross_entropy


def apply_liger_kernel_to_megatron(
```
Can we move it to another file like monkey_patch.py under the same directory? If we want to add more kernels besides CE, it would be cleaner to separate the framework-level and kernel-specific logic. You can mirror src/liger_kernel/transformers/:

```
src/liger_kernel/megatron/
    monkey_patch.py      # apply_liger_kernel_to_megatron + TP check
    cross_entropy.py     # _build_wrapper + _patch_fused_vocab_parallel_ce
    other_future_kernel.py
```
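A rough sketch of what that split could look like; the module layout and helper names follow the suggestion above and are illustrative, not code that exists in this PR:

```python
# src/liger_kernel/megatron/monkey_patch.py -- illustrative sketch only.
# Framework-level concerns (parallel-config validation) live here; the
# kernel-specific wrapper/patch helpers live in cross_entropy.py.
from megatron.core import parallel_state

# Hypothetical per-kernel module following the layout suggested above.
from liger_kernel.megatron.cross_entropy import _patch_fused_vocab_parallel_ce


def apply_liger_kernel_to_megatron(cross_entropy: bool = True) -> None:
    """Entry point: validate the parallel configuration, then delegate the
    actual patching to kernel-specific modules."""
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    if tp_size > 1:
        raise RuntimeError(
            "apply_liger_kernel_to_megatron currently requires tensor_model_parallel_size=1, "
            f"got {tp_size}."
        )
    if cross_entropy:
        _patch_fused_vocab_parallel_ce()
```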
Summary
Adds an `apply_liger_kernel_to_megatron()` monkey-patch that swaps Megatron-LM's native `fused_vocab_parallel_cross_entropy` for Liger's Triton cross-entropy kernel. This enables online softmax, in-place gradients, and no full-softmax materialization inside Megatron training pipelines.
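A minimal usage sketch, assuming `apply_liger_kernel_to_megatron()` takes no required arguments and is importable from the `liger_kernel` package (the exact module path and signature are defined by this PR's diff, not shown here):

```python
# Hypothetical import path -- adjust to wherever this PR places the function.
from liger_kernel import apply_liger_kernel_to_megatron

# Patch before Megatron builds the model / loss function, so every call to
# fused_vocab_parallel_cross_entropy resolves to Liger's Triton CE kernel.
apply_liger_kernel_to_megatron()

# ...then launch Megatron training as usual, e.g. megatron.training.pretrain(...).
```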
Scope:
`tensor_model_parallel_size=1` only. With TP > 1, each rank holds a sharded `[N, V/tp]` logits slice, and CE requires cross-rank all-reduces that Liger's kernel does not perform. The patch raises `RuntimeError` at patch time (via `megatron.core.parallel_state`) and again at call time (via the `tp_group` argument Megatron passes), so misconfiguration fails loudly. Vocab-parallel support is follow-up work.
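A sketch of how such a call-time guard can look; the function and argument names are illustrative (only the `tp_group` argument is taken from the description above), and it assumes `LigerCrossEntropyLoss` from `liger_kernel.transformers` with `reduction="none"`:

```python
import torch.distributed as dist

from liger_kernel.transformers import LigerCrossEntropyLoss  # Liger's Triton CE


def liger_fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_group=None):
    """Sketch of a drop-in stand-in for Megatron's fused_vocab_parallel_cross_entropy.
    Refuses to run if the tensor-parallel group spans more than one rank, because
    Liger's kernel expects the full, unsharded vocab dimension."""
    if tp_group is not None and dist.is_initialized() and dist.get_world_size(group=tp_group) > 1:
        raise RuntimeError(
            "Liger CE patch only supports tensor_model_parallel_size=1, "
            f"but got a TP group of size {dist.get_world_size(group=tp_group)}."
        )
    # Megatron expects per-token (unreduced) losses with the same shape as target.
    loss_fn = LigerCrossEntropyLoss(reduction="none")
    vocab_size = vocab_parallel_logits.size(-1)
    loss = loss_fn(vocab_parallel_logits.reshape(-1, vocab_size), target.reshape(-1))
    return loss.view_as(target)
```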
Tested on Qwen3-30B-A3B scaled MoE, 1× H100_8, BF16:
Model config:
Parallelism:
Training config:
Throughput results:

| | Throughput | Iter time |
| --- | --- | --- |
| Megatron native fused CE (baseline) | ~99 TFLOP/s/GPU | ~39,400 ms |
| Liger CE (this PR) | ~108 TFLOP/s/GPU (+9%) | ~35,900 ms |
Numerical correctness: lm_loss ~4.1e-3 in both, no NaN/skipped iterations.
Variance: Liger CE 107.7-109.1 TFLOP/s/GPU (consistent).
Test setup: Single H100 80GB, sequence length S=2048, batch size B=4, vocab sizes 4K → 131K. Each provider runs the same cross-entropy operation, just with a different implementation.
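A rough standalone sketch of such a vocab-size sweep (not the harness that produced the numbers above); it assumes `LigerCrossEntropyLoss` from `liger_kernel.transformers` as the Liger provider and `torch.nn.functional.cross_entropy` as the baseline:

```python
import torch
import torch.nn.functional as F

from liger_kernel.transformers import LigerCrossEntropyLoss


def bench(loss_fn, logits, target, iters=20):
    """Average forward+backward time in ms, measured with CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        logits.grad = None
        loss_fn(logits, target).backward()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


B, S = 4, 2048
for vocab in (4_096, 32_768, 131_072):
    logits = torch.randn(B * S, vocab, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    target = torch.randint(0, vocab, (B * S,), device="cuda")
    # Baseline runs on its own copy so Liger's in-place gradient writes don't interfere.
    torch_logits = logits.detach().clone().requires_grad_(True)
    torch_ms = bench(F.cross_entropy, torch_logits, target)
    liger_ms = bench(LigerCrossEntropyLoss(), logits, target)
    print(f"V={vocab:>7}: torch {torch_ms:.2f} ms, liger {liger_ms:.2f} ms")
```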
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence